from db.db_base import DBBase, commit, get_one_value, get_one_dict, get_all_dict


# Public API of this module: `from ... import *` exports only the DBBill class.
__all__ = ["DBBill"]


class DBBill(DBBase):
    """Data-access helpers for the `Bill` table.

    Every method returns a parameterised SQL statement. Its `?` placeholders
    appear in the same order as the method's parameters. The only exception is
    the unfiltered branch of `get_all_info`, which has no placeholder.

    The decorators imported from ``db.db_base`` turn each method into an
    actual database call. Presumably they execute the returned SQL with the
    call's arguments as bind values; the commit/fetch behaviour is suggested by
    their names. This is defined in db_base, not here — confirm there.
    """

    @commit
    def add_info(self, bill_id: str, cust_id: str, bill_datetime: str, bill_type: str, bill_amount: str, remark: str) -> None:
        """Insert a new bill row.

        Returns the INSERT statement. Its six placeholders are, in order:
        bill_id, cust_id, bill_datetime, bill_type, bill_amount, remark.
        This matches the parameter order.
        """
        # Insert a record.
        # NOTE(review): the `-> None` annotation does not match this body,
        # which returns the SQL string. Presumably it describes the result of
        # the @commit-wrapped call — confirm against db.db_base.
        return """
        INSERT INTO `Bill` 
        (`bill_id`, `cust_id`, `bill_datetime`, `bill_type`, `bill_amount`, `remark`)
        VALUES (?, ?, ?, ?, ?, ?);
        """

    @commit
    def delete_info(self, bill_id: str):
        """Delete the bill row(s) whose `bill_id` equals *bill_id*."""
        # Delete a record.
        return """
        DELETE FROM `Bill`
        WHERE `bill_id` = ?;
        """

    @commit
    def update_bill_info(self, cust_id: str, bill_datetime: str, bill_type: str, bill_amount: str, remark: str, bill_id: str):
        """Overwrite all non-key columns of the bill identified by *bill_id*.

        Returns the UPDATE statement. The five SET placeholders come first and
        the WHERE placeholder last.
        """
        # Update a record.
        # bill_id is deliberately the LAST parameter: its placeholder in the
        # WHERE clause follows the five SET placeholders.
        return """
        UPDATE `Bill`
        SET `cust_id` = ?, `bill_datetime` = ?, `bill_type` = ?, `bill_amount` = ?, `remark` = ? 
        WHERE `bill_id` = ?;
        """

    @get_all_dict("Bill")
    def get_all_info(self, cust_id: str = "*"):
        """List bills, newest `bill_datetime` first.

        cust_id: filter to this customer. The default sentinel "*" returns
        bills for all customers, using a query with no placeholder.
        """
        # Get all records (optionally filtered by customer).
        # NOTE(review): the "*" branch has no `?` placeholder. If the decorator
        # binds an explicitly passed "*" as an argument, the binding count will
        # not match — confirm how get_all_dict forwards arguments. Also, a
        # customer whose id is literally "*" cannot be selected via this method.
        return """
            SELECT `bill_id`, `cust_id`, `bill_datetime`, `bill_type`, `bill_amount`, `remark`
            FROM `Bill`
            ORDER BY `bill_datetime` DESC;
            """ \
            if cust_id == "*" else """
            SELECT `bill_id`, `cust_id`, `bill_datetime`, `bill_type`, `bill_amount`, `remark`
            FROM `Bill`
            WHERE `cust_id` = ?
            ORDER BY `bill_datetime` DESC;
            """

    @get_one_dict("Bill")
    def get_info(self, bill_id: str):
        """Fetch the bill whose `bill_id` equals *bill_id*.

        Uses `SELECT *`, so every column of `Bill` is selected. The list
        methods use an explicit column list instead.
        """
        # Get a single record.
        return """
        SELECT *
        FROM `Bill`
        WHERE `bill_id` = ?;
        """

    @get_one_value
    def get_amount_by_id(self, bill_id: str):
        """Fetch `bill_amount` for the bill identified by *bill_id*."""
        # Get the collected (received payment) amount.
        return """
        SELECT `bill_amount`
        FROM `Bill`
        WHERE `bill_id` = ?;
        """

    @get_one_value
    def get_cust_id_by_id(self, bill_id: str):
        """Fetch the `cust_id` of the bill identified by *bill_id*."""
        # Get the customer id.
        return """
        SELECT `cust_id`
        FROM `Bill`
        WHERE `bill_id` = ?;
        """

    @get_all_dict("Bill")
    def get_all_info_by_cust(self, cust_id: str):
        """List all bills belonging to *cust_id*.

        Selects the same columns and rows as `get_all_info(cust_id)`, but has
        no ORDER BY clause, so the database does not guarantee row order.
        """
        # NOTE(review): consider delegating to / aligning with get_all_info,
        # which orders by `bill_datetime` DESC.
        return """
        SELECT `bill_id`, `cust_id`, `bill_datetime`, `bill_type`, `bill_amount`, `remark`
        FROM `Bill`
        WHERE `cust_id` = ?;
        """
